import numpy as np
from numpy import *

## Compute matrix R (Eq. 3 in the paper)
def compute_matrix(p, lamda_1, lamda_2, alpha):
    '''
    Build the matrix R (Eq. 3 in the paper) and return the sum of its entries.

    Only the round(len(p) * alpha) most dominant predictions are used; they
    are picked by Majority (handling noisy labels via majority selection,
    see Sec 3.3).  For the selected predictions q:

    * R[i, i] = -lamda_1 * (q_i*log(q_i) + (1 - q_i)*log(1 - q_i))
      (entropy term);
    * R[i, j] = lamda_2 * ((q_i - q_j) * log(q_i / q_j)
                           + (q_j - q_i) * log((1 - q_i) / (1 - q_j)))
      for i != j (diversity term, symmetric in i and j).

    Parameters
    ----------
    p : sequence of float
        Predictions; the formulas take log(p) and log(1 - p), so values are
        expected to be probabilities.
    lamda_1 : number
        Weight of the entropy term (1 turns it on, 0 turns it off).
    lamda_2 : number
        Weight of the diversity term (1 turns it on, 0 turns it off).
    alpha : number
        Fraction of the predictions kept by the majority selection.

    Returns
    -------
    numpy.float64
        Sum of all entries of R.  A term whose weight is 0 is not evaluated.
        With the diversity term enabled, the result is inf when a prediction
        of exactly 0 or 1 is paired with a different prediction.
    '''
    # int(): depending on what the file's star import exposes, `round` may be
    # numpy's, which returns a float for float input; a float cannot be used
    # as the slice bound inside Majority.
    top = int(round(len(p) * alpha))
    q = np.asarray(Majority(p, top), dtype=float)
    n = q.size

    R = np.zeros((n, n))

    if lamda_2 != 0:
        q_i = q[:, np.newaxis]
        q_j = q[np.newaxis, :]
        gap = q_i - q_j
        # Predictions of exactly 0 or 1 lead to log(0) or x/0 below.  For
        # differing predictions the resulting inf is the correct limit; the
        # nan produced for identical predictions (0/0) is overwritten next.
        with np.errstate(divide='ignore', invalid='ignore'):
            pairwise = gap * (np.log(q_i / q_j)
                              - np.log((1 - q_i) / (1 - q_j)))
        # Identical predictions contribute nothing (this covers the diagonal).
        pairwise[gap == 0] = 0.0
        R += lamda_2 * pairwise

    if lamda_1 != 0:
        R[np.diag_indices(n)] = -lamda_1 * (_xlogx(q) + _xlogx(1 - q))

    # np.sum without an axis adds up every entry of R and yields a scalar.
    return np.sum(R)


def _xlogx(x):
    '''Return x * log(x) element-wise for a float array, using 0 at x == 0.'''
    out = np.zeros_like(x)
    nonzero = x != 0
    out[nonzero] = x[nonzero] * np.log(x[nonzero])
    return out

# Sec 3.3. Handling noisy labels via majority selection
# Eq. 4: calculate the mean and sort based on the dominant predictions.
def Majority(p,top):
    '''
    Keep the `top` most dominant predictions of p.

    When the mean of p is above 0.5 the predictions are ordered from largest
    to smallest, otherwise from smallest to largest; the first `top` entries
    of that ordering are returned as a list.
    '''
    # A mean above 0.5 means high predictions dominate, so rank them first.
    high_dominates = bool(np.mean(p) > 0.5)
    ranked = sorted(p, reverse=high_dominates)
    return ranked[:top]

def test_entroy():
    '''
    Print the values of Table 1 for prediction pattern A.

    Each line shows compute_matrix(A, lamda_1, lamda_2, alpha) rounded to two
    decimals, for the entropy-only, diversity-only and combined settings,
    each with alpha = 1 and alpha = 0.25.  The unrounded entropy of pattern A
    is printed last.
    '''
    A = [0.4, 0.4, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.6, 0.6]
    # Further prediction patterns (presumably the other rows of Table 1);
    # kept for reference, they are not evaluated by this function.
    B = [0.0, 0.1, 0.2, 0.3, 0.4, 0.4, 0.6, 0.7, 0.8, 1.0, 1.0]
    C = [0.0, 0.0, 0.0, 0.1, 0.1, 0.9, 0.9, 1.0, 1.0, 1.0, 1.0]
    D = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9, 0.9, 0.9, 0.9]
    E = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.1, 0.1, 0.1]
    F  = [0.0, 0.0, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.3, 0.9, 1.0]
    G  = [0.0, 0.1, 0.7, 0.8, 0.8, 0.9, 0.9, 0.9, 0.9, 1.0, 1.0]

    # (label, lamda_1, lamda_2, alpha) for every column of the pattern-A row.
    settings = [
        ('Entropy_A = ', 1, 0, 1),
        ('Entropy^1/4_A = ', 1, 0, 0.25),
        ('Diversity_A = ', 0, 1, 1),
        ('Diversity^1/4_A = ', 0, 1, 0.25),
        ('(Entropy+Diversity)_A = ', 1, 1, 1),
        ('(Entropy+Diversity)^1/4_A = ', 1, 1, 0.25),
    ]

    ### Produce values in Table 1 by row pattern A
    # Plain print is used instead of numpy's disp, which is only reachable
    # through the star import and no longer exists in NumPy 2.
    print('Table 1. pattern A')
    for label, lamda_1, lamda_2, alpha in settings:
        value = round(compute_matrix(A, lamda_1, lamda_2, alpha), 2)
        print(label + str(value))
    print(' ')
    b = compute_matrix(A, 1, 0, 1)
    # The format string needs a placeholder; without one `%` raises TypeError.
    print('b is: %s' % b)